import _pickle
import sys
import cv2
import skvideo.io
import skvideo.datasets
import ffmpeg
import os
import pickle
import numpy as np
import time
import shutil

def input_time(path=None):
    """Load the saliency file and return its timestamp column.

    The file is first read with ``pickle.load(..., encoding='bytes')`` (which
    also accepts pickles written by Python 2); if that raises
    ``UnpicklingError`` the file is read with ``np.load`` instead.
    Only use this on trusted files: unpickling can execute arbitrary code.

    Args:
        path: Saliency data file. Defaults to the module-level
            ``saliencyBasePath`` (assigned in the ``__main__`` section).

    Returns:
        ``data[:, 0]`` of the loaded 2-D array, i.e. the first column.
    """
    if path is None:
        path = saliencyBasePath
    with open(path, 'rb') as f:
        try:
            data = pickle.load(f, encoding='bytes')
        except _pickle.UnpicklingError:
            # Not a pickle stream (e.g. a .npy file): use NumPy's loader.
            data = np.load(path, allow_pickle=True)
    return data[:, 0]


def get_frames_by_times(prefix_file, video_path=None):
    """Extract one frame per saliency timestamp by seeking with OpenCV.

    For each timestamp from ``input_time()`` except the last 21, seek the
    capture to that time, read one frame, resize it to 160x90 and write it to
    ``./framed/<prefix_file>/<n>.jpg`` with ``n`` counting from 1.
    Extraction stops early if a frame cannot be read.

    Args:
        prefix_file: Name of the output sub-folder under ``./framed``.
        video_path: Video file to read. Defaults to ``videoBasePath + im_name``
            (module-level names assigned in the ``__main__`` section).
    """
    # Output images go to ./framed/<prefix_file>/ under the working directory.
    outputDirName = './framed/' + prefix_file + '/'
    if os.path.exists(outputDirName):
        print(f"图片文件夹'{outputDirName}'已经存在")
    else:
        os.makedirs(outputDirName)
        print(f"新建文件夹'{outputDirName}'")

    if video_path is None:
        video_path = videoBasePath + im_name

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f'视频帧率是{fps}，总帧数{cap.get(cv2.CAP_PROP_FRAME_COUNT)}')

        start = time.time()
        timestamp = input_time()
        # The previous loop stopped at frameId == len(timestamp) - 20, i.e. it
        # handled len(timestamp) - 21 timestamps. Keep that count, but clamp at
        # 0 (short sequences used to run past the end and raise IndexError).
        n_wanted = max(len(timestamp) - 21, 0)
        written = 0
        for t in timestamp[:n_wanted]:
            # Timestamps are in seconds (the other extractors use fps * t and
            # ffmpeg ss=t), while CAP_PROP_POS_MSEC expects milliseconds.
            cap.set(cv2.CAP_PROP_POS_MSEC, round(t * 1000))
            ok, image = cap.read()
            if not ok or image is None:
                # cv2.resize(None, ...) would raise; stop cleanly instead.
                print(f'cannot read frame at {t}s, stop')
                break
            written += 1
            cv2.imwrite(outputDirName + f'{written}.jpg', cv2.resize(image, (160, 90)))
        print(f'have get {written} frames. break')

        end = time.time()
        print(f'{written}张图片提取结束 --> resize to (160, 90)，总用时{end - start}')
    finally:
        cap.release()


def get_frames_by_nums(prefix_file, video_path=None):
    """Extract one frame per saliency timestamp by decoding sequentially.

    The video is decoded frame by frame with a 1-based counter ``times``.
    For every timestamp ``t`` (seconds) from ``input_time()``, the first frame
    whose counter reaches ``round(fps * t)`` is saved at full size to
    ``./framed/<prefix_file>/<n>.jpg`` with ``n`` counting from 1. Decoding
    stops when every timestamp has a frame or the video ends.

    Args:
        prefix_file: Name of the output sub-folder under ``./framed``.
        video_path: Video file to read. Defaults to ``videoBasePath + im_name``
            (module-level names assigned in the ``__main__`` section).
    """
    # Output images go to ./framed/<prefix_file>/ under the working directory.
    outputDirName = './framed/' + prefix_file + '/'
    if os.path.exists(outputDirName):
        print(f"图片文件夹'{outputDirName}'已经存在")
    else:
        os.makedirs(outputDirName)
        print(f"新建文件夹'{outputDirName}'")

    if video_path is None:
        video_path = videoBasePath + im_name

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f'视频帧率是{fps}，总帧数{cap.get(cv2.CAP_PROP_FRAME_COUNT)}')

        start = time.time()
        timestamp = input_time()
        n_ts = len(timestamp)
        frameId = 1  # 1-based index of the next timestamp / output image
        times = 0    # 1-based number of the frame just decoded
        res = None
        while frameId <= n_ts:
            res, image = cap.read()
            if not res or image is None:
                break
            times += 1
            # ">=" instead of "==": one decoded frame may serve several
            # timestamps (duplicates, two timestamps within one frame interval,
            # or round(fps * t) == 0 for a small t > 0). With "==" such a
            # timestamp never matched and every later timestamp was dropped.
            while frameId <= n_ts and times >= round(fps * timestamp[frameId - 1]):
                cv2.imwrite(outputDirName + f'{frameId}.jpg', image)
                frameId += 1
        print(f'have get {frameId - 1} frames. break res:', res)

        end = time.time()
        print(f'{frameId - 1}张图片 提取结束，总用时{end - start}')
    finally:
        cap.release()


def resize(img_path='D:/VR_project/saliency-convnet-7.31/framed/1-7-Cooking Battle/',
           img_resize_path='D:/VR_project/saliency-convnet-7.31/framed/Cooking_resize_-19/',
           size=(640, 360)):
    """Bring every frame image in a folder to one size and renumber it.

    Images in ``img_path`` named ``<int>.<ext>`` are processed in numeric
    order. An image is resized to ``size`` (width, height; 640x360 by
    default) unless it already has that size. It is then written to
    ``img_resize_path`` as ``1.jpg``, ``2.jpg``, ...

    Args:
        img_path: Source folder of numbered frame images.
        img_resize_path: Destination folder; created if missing.
        size: Target ``(width, height)`` as taken by ``cv2.resize``.

    Raises:
        OSError: If a source image cannot be read.
    """
    # cv2.imwrite does not create folders and fails silently without one.
    os.makedirs(img_resize_path, exist_ok=True)
    width, height = size
    img_list = sorted(os.listdir(img_path), key=lambda name: int(os.path.splitext(name)[0]))
    length = len(img_list)
    count = 0
    for count, name in enumerate(img_list, start=1):
        src = os.path.join(img_path, name)
        image = cv2.imread(src)
        if image is None:
            raise OSError(f'cannot read image {src!r}')
        # image.shape is (rows, cols, channels) = (height, width, 3).
        if image.shape[:2] != (height, width):
            image = cv2.resize(image, (width, height))
        cv2.imwrite(os.path.join(img_resize_path, f'{count}.jpg'), image)
        print(f'\r[{count}]/[{length}]', end=' ')
    print(f'\nresize {count} images done!')

def read_frame_by_num(in_file, frame_num):
    """
    Read a single frame chosen by frame index and return it as JPEG bytes.

    ffmpeg's ``select`` filter keeps only frames whose index ``n`` is
    ``>= frame_num``. ``vframes=1`` then stops after the first of them, which
    is encoded with the ``mjpeg`` codec into an ``image2`` stream on stdout.
    """
    select_expr = 'gte(n,{})'.format(frame_num)
    node = ffmpeg.input(in_file).filter('select', select_expr)
    node = node.output('pipe:', vframes=1, format='image2', vcodec='mjpeg')
    stdout_data, _stderr_data = node.run(capture_stdout=True)
    return stdout_data

def read_frame_by_time(in_file, time):
    """
    Read the frame at a given time offset and return it as JPEG bytes.

    ``time`` is passed to ffmpeg as the input seek option ``ss``. Inside this
    function the parameter shadows the ``time`` module. One frame is encoded
    with the ``mjpeg`` codec into an ``image2`` stream on stdout.
    """
    source = ffmpeg.input(in_file, ss=time)
    sink = source.output('pipe:', vframes=1, format='image2', vcodec='mjpeg')
    jpeg_bytes, _unused_err = sink.run(capture_stdout=True)
    return jpeg_bytes

def get_frames_by_time_ffmpeg(prefix_file, video_path=None):
    """Extract one frame per saliency timestamp by seeking with ffmpeg.

    For each timestamp ``t`` from ``input_time()``, ``read_frame_by_time``
    asks ffmpeg for one JPEG frame at ``ss=t``. The frame is decoded and
    written at full size to ``./framed/<prefix_file>/<n>.jpg`` with ``n``
    counting from 1. Extraction stops at the first timestamp for which no
    frame could be decoded.

    Args:
        prefix_file: Name of the output sub-folder under ``./framed``.
        video_path: Video file to read. Defaults to ``videoBasePath + im_name``
            (module-level names assigned in the ``__main__`` section).
    """
    # Output images go to ./framed/<prefix_file>/ under the working directory.
    outputDirName = './framed/' + prefix_file + '/'
    if os.path.exists(outputDirName):
        print(f"图片文件夹'{outputDirName}'已经存在")
    else:
        os.makedirs(outputDirName)
        print(f"新建文件夹'{outputDirName}'")

    if video_path is None:
        video_path = videoBasePath + im_name

    # OpenCV is only used here to report the frame rate and frame count.
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f'视频帧率是{fps}，总帧数{cap.get(cv2.CAP_PROP_FRAME_COUNT)}')
    finally:
        cap.release()

    start = time.time()
    timestamp = input_time()
    written = 0
    for t in timestamp:
        # Read from the local video path, not a global that only exists when
        # the file runs as a script.
        out = read_frame_by_time(video_path, t)
        image = None
        if out:
            image_array = np.asarray(bytearray(out), dtype="uint8")
            image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        if image is None:
            # Empty/undecodable output (e.g. seeking past the end); imwrite
            # would raise on None.
            print(f'no frame decoded at {t}s, stop')
            break
        written += 1
        cv2.imwrite(outputDirName + f'{written}.jpg', image)

    end = time.time()
    # Frames are saved at full size; the old message wrongly claimed a resize.
    print(f'{written}张图片提取结束，总用时{end - start}')


def get_video_info(in_file):
    """
    Return the ffprobe description (a dict) of the first video stream.

    Exits the process with status 1 in two cases. If ffprobe fails, its
    stderr is printed first. If the file has no video stream, a message is
    printed to stderr first.
    """
    try:
        streams = ffmpeg.probe(in_file)['streams']
    except ffmpeg.Error as err:
        print(str(err.stderr, encoding='utf8'))
        sys.exit(1)

    first_video = next(
        (entry for entry in streams if entry['codec_type'] == 'video'),
        None,
    )
    if first_video is None:
        print('No video stream found', file=sys.stderr)
        sys.exit(1)
    return first_video

def sample(source_path='D:/VR_project/train-saliency/framed/Rhinos_resize/',
           des_path='D:/VR_project/train-saliency/framed/sample/',
           sal_path='D:/VR_project/PanoSaliency/data/saliency_ds2_topic8',
           out_file='./framed/9-sample_sal_rhinos.npy'):
    """Sub-sample frames and their saliency entries for the training set.

    Images in ``source_path`` named ``<int>.<ext>`` are sorted numerically.
    Starting at index ``3 * step``, every ``step``-th image is copied to
    ``des_path`` as ``<baseId>.jpg`` with ``baseId`` starting at 1279.
    ``step`` is 60, and becomes 25 once the index reaches 690. Element ``[2]``
    of the paired saliency record is collected for each copied image, and the
    collected entries are saved to ``out_file`` with ``np.save``.

    Args:
        source_path: Folder of numbered frame images.
        des_path: Destination folder for the sampled images (created if missing).
        sal_path: Saliency data file (pickle or NumPy format, trusted input only).
        out_file: ``.npy`` file receiving the sampled saliency entries.
    """
    step = 60
    name_list = sorted(os.listdir(source_path), key=lambda x: int(x[:-4]))

    # pickle/np.load(allow_pickle=True) can execute code: trusted files only.
    # The with-block closes the handle even when unpickling fails.
    with open(sal_path, 'rb') as f:
        try:
            data = pickle.load(f, encoding='bytes')
        except _pickle.UnpicklingError:
            data = np.load(sal_path, allow_pickle=True)

    print(len(data))
    os.makedirs(des_path, exist_ok=True)
    os.makedirs(os.path.dirname(out_file) or '.', exist_ok=True)
    i = 3
    baseId = 1279  # NOTE(review): presumably continues numbering of earlier videos' samples; confirm
    sample_sal = []
    while i * step < len(name_list):
        sample_frame = cv2.imread(os.path.join(source_path, name_list[i * step]))
        # NOTE(review): image index i*step is paired with record i*step - 1;
        # confirm this one-record offset is intended.
        sample_sal.append(data[i * step - 1][2])
        cv2.imwrite(os.path.join(des_path, f'{baseId}.jpg'), sample_frame)
        if i * step >= 690:
            # NOTE(review): changing step while i keeps counting makes i*step
            # jump backwards (12*60=720 -> 13*25=325), so earlier parts of the
            # video are sampled again; confirm this is intended.
            step = 25
        i += 1
        baseId += 1
    np.save(out_file, np.array(sample_sal))
    print(f"sample {i-3} frames finish for {source_path}, {baseId}")


def combinate(base='./framed/', sal_list=None, out_path=None):
    """Concatenate the per-video sampled saliency arrays into one dataset.

    Each ``.npy`` file in ``sal_list`` is loaded from ``base`` and its row
    count is printed. The arrays are concatenated along axis 0 and the result
    is saved to ``out_path``.

    Args:
        base: Folder containing the per-video ``.npy`` files.
        sal_list: File names to combine, in order. Defaults to the nine
            ``*-sample_sal_*.npy`` files produced by ``sample()``.
        out_path: Output file. Defaults to ``<base>/dataset2.npy``.

    Returns:
        The concatenated array (also written to ``out_path``).

    Raises:
        ValueError: If ``sal_list`` is empty (from ``np.concatenate``).
    """
    if sal_list is None:
        sal_list = ['1-sample_sal_skiing.npy', '2-sample_sal_cooking.npy', '3-sample_sal_conan1.npy',
                    '4-sample_sal_conan2.npy', '5-sample_sal_alien.npy', '6-sample_sal_surfing.npy',
                    '7-sample_sal_war.npy', '8-sample_sal_football.npy', '9-sample_sal_rhinos.npy']
    if out_path is None:
        out_path = os.path.join(base, 'dataset2.npy')

    parts = []
    for name in sal_list:
        part = np.load(os.path.join(base, name))
        print(part.shape[0])
        parts.append(part)

    # A single concatenate avoids re-copying the growing array for every file.
    x = np.concatenate(parts)
    np.save(out_path, x)
    print(x.shape, sum(p.shape[0] for p in parts))
    return x

if __name__ == '__main__':
    # Pick the step to run. sys.exit() below ends the script, so the frame
    # extraction pipeline after it is currently disabled.
    resize()
    # sample()
    # combinate()
    # sys.exit() instead of the site-module exit(), which is meant for the
    # interactive shell and is unavailable under `python -S`.
    sys.exit()

    # Root folders of the videos and of the saliency data. These module-level
    # names are read as globals by the frame-extraction functions.
    # videoBasePath = 'D:/VR_project/train-saliency/videos/'
    videoBasePath = 'D:/VR_project/LiveDeep_All/videos/'
    saliencyBasePath = 'D:/VR_project/PanoSaliency/data/saliency_ds2_topic8'
    # [1-1-Conan Gore Fly.mp4, 1-2-Front.mp4, 1-3-360 Google Spotlight Stories_ HELP.mp4,
    # 1-4-Conan Weird Al.mp4, 1-5-TahitiSurf.mp4, 1-6-Falluja.mp4,
    # 1-7-Cooking Battle.mp4, 1-8-Football.mp4, 1-9-Rhinos.mp4]
    im_name = '1-9-Rhinos.mp4'

    # Author's note: this approach also cannot read every frame.
    file_path = videoBasePath + im_name
    video_info = get_video_info(file_path)
    # 'nb_frames' may be missing from the probe for some containers.
    nb_frames = video_info.get('nb_frames')
    total_frames = int(nb_frames) if nb_frames is not None else None
    print('总帧数：' + str(total_frames))

    # out = read_frame_by_num(file_path, 5900)
    # # out = read_frame_by_time(file_path, 100)
    # image_array = np.asarray(bytearray(out), dtype="uint8")
    # image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    # cv2.imshow('frame', image)
    # cv2.imwrite('test.jpg', image)
    # cv2.waitKey(0)
    # exit(0)

    # splitext keeps dots inside the base name ('a.b.mp4' -> 'a.b'); the
    # extension is compared case-insensitively so '.MP4' is accepted too.
    prefix_file, extension = os.path.splitext(im_name)
    suffix_file = extension.lstrip('.').lower()
    print(im_name, prefix_file, suffix_file)

    if suffix_file in ('mp4', 'mkv', 'avi', 'webm'):
        print(f'----------提取"{suffix_file}"格式视频帧----------')
        get_frames_by_nums(prefix_file)
        # get_frames_by_time_ffmpeg(prefix_file)
    else:
        print(f'unsupported video format "{suffix_file}", nothing extracted')
